from __future__ import print_function
# 使python2.x的print语法与python3.x的print规则一样

import argparse
import os
import random
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
import torchvision.utils as vutils
from torch.autograd import Variable
import time
import numpy as np
from numpy import *
from data_loader.dataset import train_dataset
from models.u_net import UNet
from models.seg_net import Segnet
from models.fcn import FCN8s, VGGNet

#####################################################################################
# ToDo:是否使用多通道Capsule
#####################################################################################
# from models.capsule import CapsuleNet, CapsuleLoss
from models.multi_capsule import CapsuleNet, CapsuleLoss

from utils.metrics import Evaluator

def _str2bool(value):
    """Parse a boolean command-line value.

    ``type=bool`` is an argparse pitfall: ``bool('False')`` is ``True`` because
    every non-empty string is truthy, so ``--cuda False`` could never disable
    CUDA. This converter accepts the usual spellings of true/false instead.

    Raises:
        argparse.ArgumentTypeError: if the value is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = str(value).strip().lower()
    if lowered in ('1', 'true', 't', 'yes', 'y', 'on'):
        return True
    if lowered in ('0', 'false', 'f', 'no', 'n', 'off'):
        return False
    raise argparse.ArgumentTypeError('boolean value expected, got %r' % (value,))


parser = argparse.ArgumentParser(description='Training a RS_Semantic_Segmentation model')
parser.add_argument('--batch_size', type=int, default=4, help='equivalent to instance normalization with batch_size=1')
parser.add_argument('--input_nc', type=int, default=3)
parser.add_argument('--output_nc', type=int, default=1)
parser.add_argument('--niter', type=int, default=10, help='number of epochs to train for')
parser.add_argument('--lr', type=float, default=0.0001, help='learning rate, default=0.0001')
parser.add_argument('--beta1', type=float, default=0.5, help='beta1 for adam. default=0.5')
parser.add_argument('--cuda', type=_str2bool, default=True, help='enables cuda. default=True')
parser.add_argument('--manual_seed', type=int, default=2021, help='manual seed')  # manual random seed
parser.add_argument('--num_workers', type=int, default=0, help='how many threads of cpu to use while loading data')
parser.add_argument('--size_w', type=int, default=256, help='scale image to this size')
parser.add_argument('--size_h', type=int, default=256, help='scale image to this size')
parser.add_argument('--flip', type=int, default=0, help='1 for flipping image randomly, 0 for not')
parser.add_argument('--net', type=str, default='', help='path to pre-trained network')
parser.add_argument('--data_path', default='./data/train', help='path to training images')
parser.add_argument('--val_data_path', default='./data/val', help='path to validation images')

#####################################################################################
# ToDo:choose the model which will be trained
#####################################################################################
#parser.add_argument('--outf', default='./checkpoint/Unet', help='folder to output images and model checkpoints')
#parser.add_argument('--outf', default='./checkpoint/Segnet', help='folder to output images and model checkpoints')
#parser.add_argument('--outf', default='./checkpoint/FCN', help='folder to output images and model checkpoints')
#parser.add_argument('--outf', default='./checkpoint/Capsule', help='folder to output images and model checkpoints')
parser.add_argument('--outf', default='./checkpoint/MultiCapsule', help='folder to output images and model checkpoints')

# type=int is required: without it a value given on the command line stays a
# string and later breaks `epoch % opt.save_epoch`, `opt.num_GPU > 1`, etc.
parser.add_argument('--save_epoch', type=int, default=1, help='number of epoch to save parameters')
parser.add_argument('--test_step', type=int, default=300, help='number of step to eval model')
parser.add_argument('--log_step', type=int, default=1, help='number of step to write log')
parser.add_argument('--num_GPU', type=int, default=1, help='number of GPU')
opt = parser.parse_args()
# Create <outf>/model/. makedirs() on the leaf also creates <outf> itself.
# Previously both levels were created inside one try-block, so when <outf>
# already existed the first call raised and model/ was never created, which
# made every later torch.save() into model/ fail.
model_dir = os.path.join(opt.outf, 'model')
try:
    os.makedirs(model_dir)
except OSError:
    # "Already exists" is fine; anything else (e.g. permissions) is a real error.
    if not os.path.isdir(model_dir):
        raise

# NOTE: `from numpy import *` at the top of the file rebinds the bare name
# `random` to `numpy.random` (NumPy 1.x lists 'random' in numpy.__all__), so the
# old `random.seed(...)` seeded only NumPy and the stdlib RNG stayed unseeded.
# Import the stdlib module under its own name and seed every RNG explicitly.
import random as _py_random

if opt.manual_seed is None:
    opt.manual_seed = _py_random.randint(1, 10000)
_py_random.seed(opt.manual_seed)
np.random.seed(opt.manual_seed)
torch.manual_seed(opt.manual_seed)
# Lets cuDNN pick the fastest kernels; note this trades away bitwise reproducibility.
cudnn.benchmark = True

# Training data (the variable name's spelling is kept because the training
# loop refers to it).
train_datatset_ = train_dataset(opt.data_path, opt.size_w, opt.size_h, opt.flip)
train_loader = torch.utils.data.DataLoader(dataset=train_datatset_, batch_size=opt.batch_size, shuffle=True,
                                           num_workers=opt.num_workers)

def weights_init(m, zero_bias=False):
    """Initialise one module's parameters; intended for ``net.apply(weights_init)``.

    * Modules whose class name contains ``'Conv'`` (so ``ConvTranspose*`` too)
      get weights drawn from N(0, 0.02).
    * Modules whose class name contains ``'BatchNorm'`` get weights drawn from
      N(1, 0.02) and a zero bias.
    * Every other module is left untouched.

    Parameters that a layer was built without (``bias=False`` convolutions,
    ``affine=False`` BatchNorm) are skipped instead of crashing on ``None.data``.

    Args:
        m: the module being visited by ``nn.Module.apply``.
        zero_bias: also zero the bias of conv layers. Defaults to False, which
            keeps PyTorch's default conv-bias init. The original script had this
            step commented out by hand for the capsule networks.
    """
    class_name = m.__class__.__name__
    weight = getattr(m, 'weight', None)
    bias = getattr(m, 'bias', None)
    if 'Conv' in class_name:
        if weight is not None:
            weight.data.normal_(0.0, 0.02)
        if zero_bias and bias is not None:
            bias.data.fill_(0)
    elif 'BatchNorm' in class_name:
        if weight is not None:
            weight.data.normal_(1.0, 0.02)
        if bias is not None:
            bias.data.fill_(0)

#####################################################################################
# ToDo:init the model
#####################################################################################
# net = UNet(opt.input_nc, opt.output_nc)
# net = Segnet(opt.input_nc, opt.output_nc)
# net = FCN8s(pretrained_net=VGGNet(pretrained=False),n_class=opt.output_nc)
net = CapsuleNet(num_parts=5)

if opt.net != '':
    # Resume from a checkpoint. Load onto the CPU first so that a checkpoint
    # saved from GPU tensors can still be opened with --cuda False; the model
    # is moved to the GPU below when requested.
    net.load_state_dict(torch.load(opt.net, map_location='cpu'))
else:
    #################################################################################
    # ToDo: do not use this initialisation for FCN
    #################################################################################
    net.apply(weights_init)

if opt.cuda:
    net.cuda()
# int(): robust even if num_GPU arrives as a string from the command line.
if int(opt.num_GPU) > 1:
    net = nn.DataParallel(net)


###########   LOSS & OPTIMIZER   ##########
# criterion = nn.BCELoss()
#####################################################################################
# ToDo:choose the capsule loss function
#####################################################################################
# Use the configured image size rather than a hard-coded 256x256, so that
# --size_h/--size_w stay consistent with the loss (defaults are unchanged).
criterion = CapsuleLoss(height=opt.size_h, width=opt.size_w)

optimizer = torch.optim.Adam(net.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))

###########   GLOBAL VARIABLES   ###########
# Pre-allocated float buffers for one input batch and its label, laid out as
# (N, C, H, W). The initial shape is only a starting allocation: code that
# copies batches into them is expected to resize_ them to the real batch shape.
# (The old Variable(...) wrapper has been a no-op since PyTorch 0.4.)
initial_image = torch.FloatTensor(opt.batch_size, opt.input_nc, opt.size_h, opt.size_w)
semantic_image = torch.FloatTensor(opt.batch_size, opt.input_nc, opt.size_h, opt.size_w)

if opt.cuda:
    initial_image = initial_image.cuda()
    semantic_image = semantic_image.cuda()


def caculate_miou_pa(model, data_path=opt.val_data_path):
    """Compute mean IoU and pixel accuracy of ``model`` on a labelled dataset.

    Images under ``data_path`` are loaded with the same ``train_dataset``
    loader used for training. Each prediction is thresholded at 0.5 and
    compared pixel-wise with the label in a 2-class ``Evaluator``.

    The model runs in eval mode (BatchNorm uses running statistics, Dropout is
    off) under ``torch.no_grad()``. It is switched back to train mode
    afterwards, even if evaluation raises.

    Args:
        model: network whose forward returns ``(part_map, prediction)`` (the
            capsule variant); only ``prediction`` is scored.
        data_path: folder holding the evaluation images.

    Returns:
        ``(MIoU, PA)`` as returned by ``Evaluator``.
    """
    batch_size = 2
    evaluator = Evaluator(2)
    # NOTE(review): opt.flip is passed through to the evaluation set as well; if
    # train_dataset flips randomly, validation scores are not deterministic.
    dataset = train_dataset(data_path, opt.size_w, opt.size_h, opt.flip)
    # A confusion matrix does not depend on sample order, so no shuffling.
    data_loader = torch.utils.data.DataLoader(dataset=dataset,
                                              batch_size=batch_size,
                                              shuffle=False,
                                              num_workers=opt.num_workers)

    model.eval()
    try:
        # eval() alone does not stop autograd from recording the graph;
        # no_grad() does, which saves a lot of GPU memory during validation.
        with torch.no_grad():
            # Iterate the loader directly: the iterator's `.next()` method used
            # before no longer exists in recent PyTorch releases.
            for initial_image_, semantic_image_, _ in data_loader:
                if opt.cuda:
                    initial_image_ = initial_image_.cuda()
                    semantic_image_ = semantic_image_.cuda()

                #############################################################################
                # ToDo: check whether the model is a capsule network
                #############################################################################
                # semantic_image_pred = model(initial_image_)
                _part_map, semantic_image_pred = model(initial_image_)

                # Replicate the 1-channel prediction so it lines up with the
                # 3-channel label pixel for pixel.
                semantic_image_pred = torch.cat((semantic_image_pred,) * 3, dim=1)

                label = semantic_image_.reshape(-1).cpu().numpy().astype(np.uint8)
                # Threshold at 0.5. This matches the old `(p + 0.5)` uint8
                # truncation for p in [0, 1] and never yields a class outside {0, 1}.
                pred = (semantic_image_pred.reshape(-1) >= 0.5).cpu().numpy().astype(np.uint8)
                evaluator.update_matrix(label, pred)
    finally:
        # Restore training mode (BatchNorm/Dropout) for the caller.
        model.train()

    MIoU = evaluator.Mean_Intersection_over_Union()
    PA = evaluator.Pixel_Accuracy()
    return MIoU, PA


if __name__ == '__main__':

    loss_log = []
    MIoU_log = []
    PA_log = []

    # int(): robust even if these options arrive as strings from the CLI.
    log_step = int(opt.log_step)
    test_step = int(opt.test_step)
    save_epoch = int(opt.save_epoch)
    # Denominator shown in the progress line (number of batches * batch size).
    progress_total = len(train_loader) * opt.batch_size

    #####################################################################################
    # ToDo:setting the dir of log
    #####################################################################################
    # The log lives inside opt.outf (created at start-up) instead of a hard-coded
    # './checkpoint/MultiCapsule/...' path, so --outf and the log stay in sync.
    # With the default --outf the file path is unchanged.
    log_path = os.path.join(opt.outf, 'train_Capsule_log.txt')

    with open(log_path, 'w') as log:
        start = time.time()
        net.train()
        print("start training...")
        for epoch in range(1, opt.niter + 1):
            # Iterate the loader directly: the iterator's `.next()` method used
            # before no longer exists in recent PyTorch releases.
            for batch_idx, (initial_image_, semantic_image_, _) in enumerate(train_loader):
                # i = index of the first sample of this batch (same as before).
                i = batch_idx * opt.batch_size

                # Same dtype/device the old pre-allocated FloatTensor buffers
                # produced.
                initial_image_ = initial_image_.float()
                semantic_image_ = semantic_image_.float()
                if opt.cuda:
                    initial_image_ = initial_image_.cuda()
                    semantic_image_ = semantic_image_.cuda()

                #############################################################################
                # ToDo: check which model is used for inference
                #############################################################################
                # semantic_image_pred = net(initial_image_)
                part_map, semantic_image_pred = net(initial_image_)

                # Replicate the 1-channel prediction so it lines up with the
                # 3-channel label pixel for pixel.
                semantic_image_pred = torch.cat((semantic_image_pred,) * 3, dim=1)

                # Flatten into NEW names. The old code rebound the global
                # `semantic_image` buffer to a view of itself on every step.
                label_flat = semantic_image_.reshape(-1)
                pred_flat = semantic_image_pred.reshape(-1)

                #############################################################################
                # ToDo: check the type of loss { BCELoss or CapsuleLoss }
                #############################################################################
                # loss = criterion(pred_flat, label_flat)
                loss = criterion(pred_flat, part_map, label_flat)
                loss_value = loss.item()
                loss_log.append(loss_value)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                ########### Logging ##########
                if i % log_step == 0:
                    message = '[%d/%d][%d/%d] Loss: %.4f' % (epoch, opt.niter, i, progress_total, loss_value)
                    print(message)
                    # The trailing newline was missing, so the whole log ended up on one line.
                    log.write(message + '\n')
                if i % test_step == 0:
                    MIoU, PA = caculate_miou_pa(net, data_path=opt.val_data_path)
                    MIoU_log.append(MIoU)
                    PA_log.append(PA)
                    print("MIoU:{},PA:{}".format(MIoU, PA))
                    log.write("MIoU:{},PA:{}\n".format(MIoU, PA))
                    # The (N, 3, H, W) prediction is saved directly, with no
                    # hard-coded 256x256 reshape.
                    vutils.save_image(semantic_image_pred.detach(),
                                      opt.outf + '/fake_samples_epoch_%03d_%03d.png' % (epoch, i),
                                      normalize=True)

            if epoch % save_epoch == 0:
                # NOTE(review): with num_GPU > 1 `net` is wrapped in DataParallel,
                # so the saved keys carry a 'module.' prefix.
                torch.save(net.state_dict(), '%s/model/netG_%s.pth' % (opt.outf, str(epoch)))

        end = time.time()

    torch.save(net.state_dict(), '%s/model/netG_final.pth' % opt.outf)

    np.save('%s/model/loss.npy' % opt.outf, np.array(loss_log))
    np.save('%s/model/MIoU.npy' % opt.outf, np.array(MIoU_log))
    np.save('%s/model/PA.npy' % opt.outf, np.array(PA_log))

    elapsed = end - start
    print('Program processed ', elapsed, 's, ', elapsed / 60, 'min, ', elapsed / 3600, 'h')